from nexa.gguf.llama.llama_cache import LlamaDiskCache
from typing import Any, Dict

def run_inference_with_disk_cache(
    model: Any,
    cache_prompt: str,
    total_prompt: str,
    use_cache: bool = True,
    cache_dir: str = "llama.cache",
    **kwargs: Any,
) -> Any:
    """
    Runs inference using a disk cache to store and retrieve model states.

    On a cache hit, the model state saved for ``cache_prompt`` is restored
    before generating; on a miss, a minimal priming pass populates the cache
    first. In both cases the final generation uses the same sampling
    parameters, so output settings do not depend on cache state.

    Parameters:
    - model: The model object that supports caching and inference
      (must expose ``tokenize``, ``set_cache``, ``load_state``,
      ``save_state``, ``reset``, and be callable for generation).
    - cache_prompt: The prompt used to generate a cache key.
    - total_prompt: The full prompt for generating output.
    - use_cache: Flag to determine if caching should be used.
    - cache_dir: Directory where cache files are stored.
    - kwargs: Optional sampling overrides: ``temperature`` (0.7),
      ``max_tokens`` (2048), ``top_p`` (0.8), ``top_k`` (50),
      ``repeat_penalty`` (1.0).

    Returns:
    - The (streaming) output generated by the model.
    """
    # Resolve sampling parameters once so every generation path uses the
    # same settings (previously the cache-hit path dropped top_p/top_k/
    # repeat_penalty, making sampling depend on cache state).
    gen_params = {
        "max_tokens": kwargs.get("max_tokens", 2048),
        "temperature": kwargs.get("temperature", 0.7),
        "top_p": kwargs.get("top_p", 0.8),
        "top_k": kwargs.get("top_k", 50),
        "repeat_penalty": kwargs.get("repeat_penalty", 1.0),
    }

    if use_cache:
        # Initialize disk cache with specified directory
        cache_context = LlamaDiskCache(cache_dir=cache_dir)
        model.set_cache(cache_context)
        # Convert prompt to tokens for cache key
        prompt_tokens = model.tokenize(cache_prompt.encode("utf-8"))

        try:
            # Try to load existing cache (EAFP: LlamaDiskCache raises
            # KeyError when no entry matches the token key)
            cached_state = cache_context[prompt_tokens]
            model.load_state(cached_state)
        except KeyError:
            # Cache miss: run a minimal inference pass over the cache
            # prompt to build the model state, then persist it.
            model.reset()
            _ = model(
                cache_prompt,
                max_tokens=1,  # Minimal tokens for cache creation
                temperature=gen_params["temperature"],
                echo=False,
            )
            cache_context[prompt_tokens] = model.save_state()
    else:
        # No caching: start from a clean state with caching disabled.
        model.reset()
        model.set_cache(None)

    # Single generation call shared by all paths, with identical sampling
    # parameters regardless of how the state was prepared.
    output = model(
        total_prompt,
        stream=True,
        **gen_params,
    )
    return output